#include "torch_xla/csrc/runtime/pjrt_registry.h"

#include <c10/util/Exception.h>

#include "absl/log/absl_check.h"
#include "absl/log/initialize.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/cpu/cpu_client.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/pjrt/distributed/in_memory_key_value_store.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/pjrt/pjrt_c_api_client.h"

#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/profiler.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "torch_xla/csrc/status.h"

namespace torch_xla {
namespace runtime {

namespace {

// Placeholder plugin for testing only. Does not implement multiprocessing or
// configuration. Very likely will not work from Python code.
class LibraryPlugin : public PjRtPlugin {
 public:
  std::string library_path() const override {
    return sys_util::GetEnvString("PJRT_LIBRARY_PATH", "");
  }

  const std::unordered_map<std::string, xla::PjRtValueType>
  client_create_options() const override {
    return {};
  }

  bool requires_xla_coordinator() const override { return false; }
};

std::unordered_map<std::string, std::shared_ptr<const PjRtPlugin>>
    pjrt_plugins_ = {{"LIBRARY", std::make_shared<LibraryPlugin>()}};

absl::StatusOr<std::shared_ptr<const PjRtPlugin>> GetPjRtPlugin(
    const std::string& device_type) {
  auto entry = pjrt_plugins_.find(device_type);
  if (entry == pjrt_plugins_.end()) {
    std::string message = absl::StrCat(
        "No PjRtPlugin registered for: ", device_type,
        ". Make sure the environment variable ", env::kEnvPjRtDevice,
        " is set to a correct device name. See "
        "https://github.com/pytorch/xla/blob/master/docs/source/"
        "contribute/plugins.md for more information on "
        "implementing and registering a new "
        "plugin.");
    return XLA_ERROR_WITH_LOCATION(absl::FailedPreconditionError(message));
  }
  return entry->second;
}

}  // namespace

void RegisterPjRtPlugin(std::string name,
                        std::shared_ptr<const PjRtPlugin> plugin) {
  TF_VLOG(3) << "Registering PjRt plugin " << name;
  pjrt_plugins_[name] = plugin;
}

absl::StatusOr<std::tuple<absl_nonnull std::unique_ptr<xla::PjRtClient>,
                          std::unique_ptr<XlaCoordinator>>>
InitializePjRt(const std::string& device_type) {
  std::unique_ptr<xla::PjRtClient> client;
  std::unique_ptr<XlaCoordinator> coordinator;

  if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false) &&
      device_type != "CPU") {
    TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type;
    XLA_ASSIGN_OR_RETURN(std::shared_ptr<const PjRtPlugin> plugin,
                         GetPjRtPlugin(device_type));
    if (plugin) {
      // Init the absl logging to avoid the log spam.
      absl::InitializeLog();

      std::shared_ptr<xla::KeyValueStoreInterface> kv_store = nullptr;
      if (plugin->requires_xla_coordinator()) {
        int local_process_rank = sys_util::GetEnvInt(
            env::kEnvPjRtLocalRank, sys_util::GetEnvInt("LOCAL_RANK", 0));
        int global_process_rank =
            sys_util::GetEnvInt("RANK", local_process_rank);
        int local_world_size =
            sys_util::GetEnvInt(env::kEnvPjRtLocalProcessCount,
                                sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1));
        int global_world_size =
            sys_util::GetEnvInt("WORLD_SIZE", local_world_size);

        std::string master_addr =
            runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
        std::string port = runtime::sys_util::GetEnvString(
            "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);

        TF_VLOG(3) << "Creating coordinator for rank=" << global_process_rank
                   << ", world size=" << global_world_size
                   << ", coordinator address=" << master_addr << ":" << port;

        // Use the XlaCoordinator as the distributed key-value store.
        XLA_ASSIGN_OR_RETURN(
            coordinator,
            XlaCoordinator::Create(global_process_rank, global_world_size,
                                   master_addr, port));
        std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
            coordinator->GetClient();
        kv_store = xla::GetDistributedKeyValueStore(distributed_client,
                                                    /*key_prefix=*/"pjrt:");
      }
      XLA_ASSIGN_OR_RETURN(
          const PJRT_Api* c_api,
          pjrt::LoadPjrtPlugin(absl::AsciiStrToLower(device_type),
                               plugin->library_path()));
      XLA_RETURN_IF_ERROR(pjrt::InitializePjrtPlugin(device_type));
      auto create_options = plugin->client_create_options();
      XLA_ASSIGN_OR_RETURN(
          client,
          xla::GetCApiClient(absl::AsciiStrToUpper(device_type),
                             {create_options.begin(), create_options.end()},
                             kv_store));
      profiler::RegisterProfilerForPlugin(c_api);
    }
  } else if (device_type == "CPU") {
    TF_VLOG(1) << "Initializing PjRt CPU client...";
    bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true);
    int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1);
    XLA_ASSIGN_OR_RETURN(client,
                         xla::GetPjRtCpuClient(async, cpu_device_count));
  } else if (device_type == "TPU") {
    TF_VLOG(1) << "Initializing TFRT TPU client...";
    // Init the absl logging to avoid the log spam.
    absl::InitializeLog();
    // Prefer $TPU_LIBRARY_PATH if set
    auto tpu_library_path = sys_util::GetEnvString(
        env::kEnvTpuLibraryPath,
        sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so"));
    XLA_ASSIGN_OR_RETURN(const PJRT_Api* c_api,
                         pjrt::LoadPjrtPlugin("tpu", tpu_library_path));
    XLA_RETURN_IF_ERROR(pjrt::InitializePjrtPlugin("tpu"));
    XLA_ASSIGN_OR_RETURN(client, xla::GetCApiClient("TPU"));
    profiler::RegisterProfilerForPlugin(c_api);
  } else if (device_type == "TPU_LEGACY") {
    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
        "TPU_LEGACY client is no longer available."));
  } else if (device_type == "XPU") {
    TF_VLOG(1) << "Initializing PjRt XPU client...";
    XLA_RETURN_IF_ERROR(pjrt::LoadPjrtPlugin(
        "xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so")));
    XLA_ASSIGN_OR_RETURN(client, xla::GetCApiClient("XPU"));
  } else if (device_type == "NEURON") {
    TF_VLOG(1) << "Initializing PjRt NEURON client...";
    XLA_RETURN_IF_ERROR(pjrt::LoadPjrtPlugin(
        "NEURON", sys_util::GetEnvString(env::kEnvNeuronLibraryPath,
                                         "libneuronpjrt.so")));
    XLA_ASSIGN_OR_RETURN(client, xla::GetCApiClient("NEURON"));
  } else {
    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
        "Unknown ", env::kEnvPjRtDevice, ": '", device_type, "'")));
  }

  ABSL_CHECK(client);
  return std::make_tuple(std::move(client), std::move(coordinator));
}

}  // namespace runtime
}  // namespace torch_xla
